package org.test4j.integration.junit5;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.springframework.test.context.TestContextManager;
import org.test4j.module.spring.SpringEnv;

import java.lang.reflect.Method;

import static org.test4j.mock.faking.util.ReflectUtility.doThrow;

/**
 * Static helpers that forward JUnit 5 extension lifecycle events to the Spring
 * {@link TestContextManager} kept for each test class.
 *
 * @author darui.wu
 * @date 2019/11/5 4:16 PM
 */
public class JUnit5SpringHelper {
    /**
     * {@link Namespace} in which {@code TestContextManagers} are stored,
     * keyed by test class.
     */
    private static final Namespace NAMESPACE = Namespace.create(JUnit5Extension.class);

    /**
     * Get the {@link Store} of the root {@code ExtensionContext} under this helper's namespace.
     *
     * @param context the current extension context
     * @return the root-level store scoped to {@link #NAMESPACE}
     */
    public static Store getStore(ExtensionContext context) {
        return context.getRoot().getStore(NAMESPACE);
    }

    /**
     * Get the {@link TestContextManager} associated with the supplied {@code ExtensionContext}.
     * A manager is created for the required test class on first access and
     * cached in the root store afterwards.
     *
     * @param context the current extension context, must not be {@code null}
     * @return the {@code TestContextManager} (never {@code null})
     */
    public static TestContextManager getTestContextManager(ExtensionContext context) {
        Assertions.assertNotNull(context, "ExtensionContext must not be null");
        Class<?> testClass = context.getRequiredTestClass();
        Store store = getStore(context);
        return store.getOrComputeIfAbsent(testClass, TestContextManager::new, TestContextManager.class);
    }

    /**
     * Registers the test class with {@link SpringEnv} and then runs the
     * {@code beforeTestClass} callback of its {@link TestContextManager}.
     * Any exception from the callback is handed to {@code doThrow}
     * (presumably rethrown without wrapping; confirm against ReflectUtility).
     *
     * @param context the current extension context
     */
    public static void beforeAll(ExtensionContext context) {
        SpringEnv.setSpringEnv(context.getRequiredTestClass());
        try {
            // NOTE(review): unlike the other callbacks this call is not guarded
            // by SpringEnv.isSpringEnv(...); confirm that is intentional.
            getTestContextManager(context).beforeTestClass();
        } catch (Exception e) {
            doThrow(e);
        }
    }

    /**
     * Runs the {@code afterTestClass} callback when {@link SpringEnv} reports a
     * Spring environment, and always removes the cached manager for the test
     * class from the store.
     *
     * @param context the current extension context
     * @throws Exception if the {@code afterTestClass} callback fails
     */
    public static void afterAll(ExtensionContext context) throws Exception {
        try {
            // NOTE(review): uses the no-arg isSpringEnv() while beforeEach/afterEach
            // pass the test class; verify both variants agree for this class.
            if (SpringEnv.isSpringEnv()) {
                getTestContextManager(context).afterTestClass();
            }
        } finally {
            // Drop the cached manager even when the callback above throws.
            getStore(context).remove(context.getRequiredTestClass());
        }
    }

    /**
     * Delegates to {@link SpringEnv#injectSpringBeans(Object)} for the given test object.
     *
     * @param target the test instance to process
     */
    public static void beforeMethod(Object target) {
        SpringEnv.injectSpringBeans(target);
    }

    /**
     * Runs the {@code beforeTestMethod} callback of the {@link TestContextManager}
     * when the test class is a Spring test according to {@link SpringEnv}.
     *
     * @param context the current extension context; it must provide a test
     *                instance and a test method
     * @throws Exception if the {@code beforeTestMethod} callback fails
     */
    public static void beforeEach(ExtensionContext context) throws Exception {
        // Require the instance (as afterEach does) instead of silently passing
        // null on to the TestContextManager when none is available.
        Object testInstance = context.getRequiredTestInstance();
        Method testMethod = context.getRequiredTestMethod();
        if (SpringEnv.isSpringEnv(context.getRequiredTestClass())) {
            getTestContextManager(context).beforeTestMethod(testInstance, testMethod);
        }
    }

    /**
     * Runs the {@code afterTestMethod} callback of the {@link TestContextManager}
     * when the test class is a Spring test according to {@link SpringEnv},
     * passing along the exception thrown by the test, if any.
     *
     * @param context the current extension context; it must provide a test
     *                instance and a test method
     * @throws Exception if the {@code afterTestMethod} callback fails
     */
    public static void afterEach(ExtensionContext context) throws Exception {
        Object testInstance = context.getRequiredTestInstance();
        Method testMethod = context.getRequiredTestMethod();
        // null when the test completed without throwing
        Throwable testException = context.getExecutionException().orElse(null);
        if (SpringEnv.isSpringEnv(context.getRequiredTestClass())) {
            getTestContextManager(context).afterTestMethod(testInstance, testMethod, testException);
        }
    }
}